"""
Phase 1: Data Assembly — Net vs Gross External Adjustment
==========================================================
Merges multilateral full_panel, extensions, and bilateral aggregates.
Constructs net position variables, lags, and sample flags.

Outputs:
    net_gross/data/processed/net_gross_panel.csv
    net_gross/output/tables/summary_statistics.md
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Directory layout, resolved relative to this file: PROJECT_DIR is the parent
# of the directory containing this script, and ROOT_DIR is one level above
# that. The sibling directories below supply the input CSVs read in main().
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR / "multilateral"
EXTENSIONS_DIR = ROOT_DIR / "extensions"
GRAVITY_DIR = ROOT_DIR / "gravity_bilateral"

# Put multilateral/src first on the import path.
# NOTE(review): no import from that directory is visible in this file —
# confirm whether this is still needed.
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))

# Output locations; created at import time if they do not exist yet.
DATA_OUT = PROJECT_DIR / "data" / "processed"
DATA_OUT.mkdir(parents=True, exist_ok=True)
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# ISO3 codes that receive financial_center = 1 in main().
FINANCIAL_CENTERS = ['LUX', 'IRL', 'HKG', 'SGP', 'CHE', 'NLD', 'BEL']

# ISO3 codes that receive is_oecd = 1 in main().
OECD = [
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
]


def _merge_extensions(fp):
    """Left-merge trade/income/savings/investment columns from the extensions panel.

    Reads ``extensions_panel.csv`` under ``EXTENSIONS_DIR``. If the file does
    not exist, a warning is printed and ``fp`` is returned unchanged.

    Args:
        fp: Country-year panel with ``iso3`` and ``year`` columns.

    Returns:
        The panel with whichever of the requested extension columns are
        present in the file merged in on (``iso3``, ``year``).
    """
    ext_path = EXTENSIONS_DIR / "data" / "processed" / "extensions_panel.csv"
    if not ext_path.exists():
        print("  WARNING: extensions_panel.csv not found")
        return fp

    ext = pd.read_csv(ext_path)
    ext = ext[ext['year'] <= 2024].copy()

    # Only take extension-specific columns, and only those the file provides.
    # NOTE(review): if fp already carries one of these names, the merge will
    # produce _x/_y suffixed columns — confirm the inputs never overlap.
    wanted = ['iso3', 'year', 'trade_balance_gdp', 'income_balance_gdp',
              'savings_gdp', 'investment_gdp']
    ext_cols = [c for c in wanted if c in ext.columns]
    ext_sub = ext[ext_cols].drop_duplicates(subset=['iso3', 'year'])
    fp = fp.merge(ext_sub, on=['iso3', 'year'], how='left')
    print(f"  After extensions merge: {len(fp)} obs")

    # The columns are optional (see the filter above), so the coverage report
    # must not assume they exist.
    for col in ('trade_balance_gdp', 'income_balance_gdp'):
        if col in fp.columns:
            print(f"  {col} coverage: {fp[col].notna().sum()}")
        else:
            print(f"  WARNING: {col} not available after extensions merge")
    return fp


def _fill_income_balance(fp):
    """Derive ``income_balance_gdp`` when it is absent or has < 100 observations.

    The fallback is ``ca_gdp - trade_balance_gdp`` and is only applied when
    ``trade_balance_gdp`` exists; otherwise the panel is returned as is.
    """
    missing = 'income_balance_gdp' not in fp.columns
    if missing or fp['income_balance_gdp'].notna().sum() < 100:
        if 'trade_balance_gdp' in fp.columns:
            fp['income_balance_gdp'] = fp['ca_gdp'] - fp['trade_balance_gdp']
            print("  Computed income_balance_gdp = ca_gdp - trade_balance_gdp")
    return fp


def _merge_bilateral(fp):
    """Aggregate bilateral positions by reporter-year and merge them as % of GDP.

    Reads ``bilateral_panel.csv`` under ``GRAVITY_DIR``. If the file does not
    exist, a warning is printed and ``fp`` is returned unchanged.

    Args:
        fp: Country-year panel with ``iso3`` and ``year`` columns.

    Returns:
        The panel with ``n_partners`` and the available ``agg_*_gdp`` columns
        merged in on (``iso3``, ``year``).
    """
    bil_path = GRAVITY_DIR / "data" / "processed" / "bilateral_panel.csv"
    if not bil_path.exists():
        print("  WARNING: bilateral_panel.csv not found")
        return fp

    bil = pd.read_csv(bil_path)
    bil = bil[bil['year'] <= 2024].copy()

    # Sum raw bilateral positions over partners for each reporter-year, and
    # carry the reporter's GDP along for the normalisation below.
    raw_cols = {'portfolio_total': 'agg_portfolio_total',
                'portfolio_debt': 'agg_portfolio_debt',
                'fdi_outward': 'agg_fdi_outward'}
    agg_specs = {'n_partners': ('partner', 'nunique'),
                 'ngdp_usd_i': ('ngdp_usd_i', 'first')}
    for raw_col, agg_name in raw_cols.items():
        if raw_col in bil.columns:
            agg_specs[agg_name] = (raw_col, 'sum')

    agg = bil.groupby(['reporter', 'year']).agg(**agg_specs).reset_index()

    # Express each aggregate as a percentage of reporter GDP. The 1e9 factor
    # follows the original note that ngdp_usd_i is in billions while the
    # positions are in USD — TODO confirm against the upstream panel.
    for agg_name in raw_cols.values():
        if agg_name in agg.columns:
            gdp_col = f'{agg_name}_gdp'
            agg[gdp_col] = agg[agg_name] / (agg['ngdp_usd_i'] * 1e9) * 100
            # A sum over all-missing inputs comes back as 0; store it as missing.
            agg.loc[agg[gdp_col] == 0, gdp_col] = np.nan
            agg.drop(columns=[agg_name], inplace=True)

    agg.drop(columns=['ngdp_usd_i'], inplace=True, errors='ignore')
    agg = agg.rename(columns={'reporter': 'iso3'})
    fp = fp.merge(agg, on=['iso3', 'year'], how='left')

    for c in [c for c in agg.columns if c.endswith('_gdp')]:
        if c in fp.columns:
            n = fp[c].notna().sum()
            print(f"  {c} coverage: {n} obs")
    print(f"  Bilateral reporters: {agg['iso3'].nunique()}")
    return fp


def _year_tercile(x):
    """Label one year's values Low/Mid/High by tercile; all-NaN if qcut fails."""
    try:
        return pd.qcut(x.dropna(), 3, labels=['Low', 'Mid', 'High'])
    except (ValueError, IndexError):
        return pd.Series(np.nan, index=x.index)


def _construct_variables(fp):
    """Add net positions, changes, lags, group flags and winsorized columns.

    Args:
        fp: Merged country-year panel.

    Returns:
        The panel sorted by (``iso3``, ``year``) with the derived columns added.
    """
    # Net positions by instrument
    fp['fdi_net_gdp'] = fp['fdi_assets_gdp'] - fp['fdi_liab_gdp']
    fp['debt_net_gdp'] = fp['debt_assets_gdp'] - fp['debt_liab_gdp']

    fp = fp.sort_values(['iso3', 'year'])

    # Year-on-year changes (flows proxy). diff() is row-based, so a change is
    # kept only when the previous row is exactly one year earlier; across a
    # gap in the country's series the change is left missing.
    years_since_prev = fp.groupby('iso3')['year'].diff()
    for var in ['gross_assets_gdp', 'gross_liab_gdp', 'nfa_gdp']:
        change = fp.groupby('iso3')[var].diff()
        fp[f'd_{var}'] = change.where(years_since_prev == 1)

    # 5-year lagged demographics. shift(5) is likewise row-based, so the lag
    # is kept only when the row five positions back is exactly five years
    # earlier.
    years_since_lag = fp['year'] - fp.groupby('iso3')['year'].shift(5)
    for z in ['Z_1', 'Z_2', 'Z_3']:
        lagged = fp.groupby('iso3')[z].shift(5)
        fp[f'{z}_lag5'] = lagged.where(years_since_lag == 5)

    # Income terciles within each year (based on GDP per capita)
    if 'gdp_pc_ppp' in fp.columns:
        fp['income_tercile'] = fp.groupby('year')['gdp_pc_ppp'].transform(_year_tercile)
    else:
        fp['income_tercile'] = np.nan

    # Group membership flags
    fp['financial_center'] = fp['iso3'].isin(FINANCIAL_CENTERS).astype(int)
    fp['is_oecd'] = fp['iso3'].isin(OECD).astype(int)

    # Savings-investment gap: prefer the extensions columns, else the
    # gross_national_savings / gross_investment pair; never overwrite an
    # existing column.
    if 'savings_investment_gap' not in fp.columns:
        if 'savings_gdp' in fp.columns and 'investment_gdp' in fp.columns:
            fp['savings_investment_gap'] = fp['savings_gdp'] - fp['investment_gdp']
        elif 'gross_national_savings_gdp' in fp.columns and 'gross_investment_gdp' in fp.columns:
            fp['savings_investment_gap'] = fp['gross_national_savings_gdp'] - fp['gross_investment_gdp']

    # Winsorized gross positions (pooled p1/p99), stored as <var>_w
    gross_vars = ['gross_assets_gdp', 'gross_liab_gdp', 'gross_ifi',
                  'fdi_assets_gdp', 'fdi_liab_gdp', 'port_eq_assets_gdp',
                  'debt_assets_gdp', 'debt_liab_gdp', 'fx_reserves_gdp']
    present = [v for v in gross_vars if v in fp.columns]
    for var in present:
        p1 = fp[var].quantile(0.01)
        p99 = fp[var].quantile(0.99)
        fp[f'{var}_w'] = fp[var].clip(p1, p99)
    print(f"  Winsorized {len(present)} gross position variables at p1/p99")
    return fp


def _write_summary(fp):
    """Write the summary-statistics markdown table and print variable coverage."""
    key_vars = [
        'ca_gdp', 'trade_balance_gdp', 'income_balance_gdp',
        'gross_assets_gdp', 'gross_liab_gdp', 'gross_ifi', 'nfa_gdp',
        'fdi_assets_gdp', 'fdi_liab_gdp', 'fdi_net_gdp',
        'debt_assets_gdp', 'debt_liab_gdp', 'debt_net_gdp',
        'port_eq_assets_gdp', 'fx_reserves_gdp',
        'savings_investment_gap', 'gross_national_savings_gdp', 'gross_investment_gdp',
        'Z_1', 'Z_2', 'Z_3', 'kaopen',
        'agg_portfolio_total_gdp', 'agg_portfolio_debt_gdp', 'agg_fdi_outward_gdp',
    ]
    key_vars = [v for v in key_vars if v in fp.columns]

    stats = fp[key_vars].describe().T[['count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max']]

    lines = ["# Summary Statistics: Net vs Gross External Positions\n"]
    lines.append("| Variable | N | Mean | SD | Min | P25 | P50 | P75 | Max |")
    lines.append("|:---|---:|---:|---:|---:|---:|---:|---:|---:|")
    for var, row in stats.iterrows():
        # iterrows() yields each row as a single-dtype (float) Series, so the
        # count has to be cast here or it is rendered as e.g. "1234.0".
        lines.append(f"| {var} | {int(row['count'])} | {row['mean']:.3f} | {row['std']:.3f} | "
                     f"{row['min']:.3f} | {row['25%']:.3f} | {row['50%']:.3f} | "
                     f"{row['75%']:.3f} | {row['max']:.3f} |")
    lines.append("\n*Source: Full panel filtered to year ≤ 2024.*")

    # The text contains non-ASCII characters; pin the encoding so the write
    # does not depend on the platform's default locale.
    (OUT_TABLES / "summary_statistics.md").write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {OUT_TABLES / 'summary_statistics.md'}")

    # Coverage by variable group
    print("\n  Variable coverage:")
    for v in key_vars:
        n = fp[v].notna().sum()
        nc = fp.loc[fp[v].notna(), 'iso3'].nunique()
        print(f"    {v:35s}  N={n:6d}  countries={nc:3d}")


def main():
    """Assemble the net-vs-gross panel and write it with a summary table.

    Steps: load ``full_panel.csv`` (years <= 2024), merge the extensions and
    bilateral aggregates when their files exist, construct derived variables,
    save ``net_gross_panel.csv`` to ``DATA_OUT`` and ``summary_statistics.md``
    to ``OUT_TABLES``. Progress is reported on stdout.

    Raises:
        FileNotFoundError: If ``full_panel.csv`` is missing (raised by pandas).
        KeyError: If a column the pipeline requires unconditionally (e.g.
            ``iso3``, ``year``, ``fdi_assets_gdp``) is absent from the inputs.
    """
    print("=" * 70)
    print("PHASE 1: DATA ASSEMBLY — NET VS GROSS")
    print("=" * 70)

    # ── 1. Load multilateral full panel ────────────────────────────────
    print("\n[1] Loading full_panel.csv ...")
    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")
    fp = fp[fp['year'] <= 2024].copy()
    print(f"  Full panel: {len(fp)} obs, {fp['iso3'].nunique()} countries")

    # ── 2. Merge extensions (trade/income balance) ─────────────────────
    print("\n[2] Merging extensions panel ...")
    fp = _merge_extensions(fp)
    fp = _fill_income_balance(fp)

    # ── 3. Aggregate bilateral positions ───────────────────────────────
    print("\n[3] Aggregating bilateral panel ...")
    fp = _merge_bilateral(fp)

    # ── 4. Construct variables ─────────────────────────────────────────
    print("\n[4] Constructing variables ...")
    fp = _construct_variables(fp)

    # ── 5. Save panel ──────────────────────────────────────────────────
    print("\n[5] Saving panel ...")
    fp.to_csv(DATA_OUT / "net_gross_panel.csv", index=False)
    print(f"  Saved: {DATA_OUT / 'net_gross_panel.csv'}")
    print(f"  Shape: {fp.shape}")
    print(f"  Countries: {fp['iso3'].nunique()}")
    print(f"  Year range: {fp['year'].min()}-{fp['year'].max()}")

    # ── 6. Summary statistics ──────────────────────────────────────────
    print("\n[6] Summary statistics ...")
    _write_summary(fp)

    print("\n" + "=" * 70)
    print("PHASE 1 COMPLETE")
    print("=" * 70)


# Run the assembly pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
